import numpy as np
import pickle
import os
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader


def convert_data_type_and_remove_redundant(data, dataset):
  """Normalize one split dict in place: key names, dtypes, and labels.

  Args:
    data: dict for a single split holding modality arrays. Canonical keys are
        'vision'/'text'/'audio'/'labels'; the datasets handled below ship the
        raw feature-extractor key names instead and are renamed here.
    dataset: dataset identifier, e.g. 'iemocap', 'mosi', 'mosei_senti',
        'mosei_full_emo', 'mosei_full_senti', 'mosi_new'.

  Returns:
    The same dict, mutated in place: modalities cast to float32, -inf audio
    values zeroed, an 'id' entry (None when absent), and dataset-specific
    label dtype/slicing applied.
  """
  # Different pickles use raw feature-extractor names; align them to the
  # canonical keys used everywhere else.
  if dataset in ['mosei_full_emo', 'mosei_full_senti']:
    data['vision'] = data.pop('FACET 4.2')
    data['text'] = data.pop('glove_vectors')
    data['audio'] = data.pop('COVAREP')
    data['labels'] = data.pop('All Labels')
  elif dataset == 'mosi_new':
    # Same fields, but this pickle spells the FACET key with an underscore.
    data['vision'] = data.pop('FACET_4.2')
    data['text'] = data.pop('glove_vectors')
    data['audio'] = data.pop('COVAREP')
    data['labels'] = data.pop('All Labels')

  data['vision'] = data['vision'].astype(np.float32)
  data['text'] = data['text'].astype(np.float32)
  data['audio'] = data['audio'].astype(np.float32)
  # Some audio frames contain -inf (presumably from log features on silent
  # frames — TODO confirm); zero them so downstream math stays finite.
  data['audio'][data['audio'] == -np.inf] = 0
  # dict.get gives the same result as the old `'id' in data.keys()` check
  # (None when the pickle shipped no ids) with a single lookup.
  data['id'] = data.get('id')

  # Dataset-specific label post-processing.
  if dataset == 'iemocap':
    # Only one column is needed for binary classification.
    data['labels'] = data['labels'][:, :, 1]
    data['labels'] = data['labels'].astype(np.int64)
  elif dataset in ['mosei_full_emo', 'mosei_full_senti']:
    if dataset == 'mosei_full_senti':
      # Column 0 holds the sentiment score.
      data['labels'] = data['labels'][:, :, 0].astype(np.float32)
    else:  # 'mosei_full_emo'
      # Columns 1: are the emotion labels ([:,:,0] is sentiment);
      # threshold at 0 for binary classification.
      data['labels'] = data['labels'][:, :, 1:]
      data['labels'] = (data['labels'] > 0).astype(np.int64)
  else:
    data['labels'] = data['labels'].astype(np.float32)
  return data


def get_data(data_folder, dataset, aligned=True):
  """Load and normalize the train/valid/test splits of a dataset pickle.

  Args:
    data_folder: directory containing the '*_data.pkl' files.
    dataset: dataset identifier (see convert_data_type_and_remove_redundant).
    aligned: if True load '<name>_data.pkl', else '<name>_data_noalign.pkl'.

  Returns:
    Tuple of (train, valid, test) split dicts, each already passed through
    convert_data_type_and_remove_redundant.
  """
  # Both mosei_full variants share one pickle on disk.
  if dataset in ['mosei_full_emo', 'mosei_full_senti']:
    filename = 'mosei_full'
  else:
    filename = dataset
  suffix = '_data.pkl' if aligned else '_data_noalign.pkl'
  dataset_path = os.path.join(data_folder, filename + suffix)
  # Use a context manager so the file handle is closed deterministically;
  # the previous `pickle.load(open(...))` leaked it.
  with open(dataset_path, 'rb') as f:
    raw = pickle.load(f)
  # Some pickles store the splits as a positional sequence rather than a
  # dict keyed by split name; normalize to the dict form.
  if dataset in ['mosei_full_emo', 'mosei_full_senti', 'mosi_new']:
    data = {'train': raw[0], 'valid': raw[1], 'test': raw[2]}
  else:
    data = raw
  return (convert_data_type_and_remove_redundant(data['train'], dataset),
          convert_data_type_and_remove_redundant(data['valid'], dataset),
          convert_data_type_and_remove_redundant(data['test'], dataset))


def get_mean_and_std(data):
  """Per-feature mean and std over the batch and time axes.

  Args:
    data: numpy array reduced over axes (0, 1), leaving the feature axis.

  Returns:
    (mean, std) arrays; constant features get std 1 instead of 0 so that a
    later (x - mean) / std normalization never divides by zero.
  """
  reduce_axes = (0, 1)
  mean = data.mean(reduce_axes)
  std = (data - mean).std(reduce_axes)
  std = np.where(std == 0., 1., std)
  return mean, std


def get_mean_and_std_3_modal(data):
  """Compute per-modality statistics for a split dict.

  Returns:
    Flat list [mean_a, std_a, mean_v, std_v, mean_l, std_l] for the
    audio, vision, and text modalities, in that order.
  """
  stats = []
  for modality in ('audio', 'vision', 'text'):
    mean, std = get_mean_and_std(data[modality])
    stats.append(mean)
    stats.append(std)
  return stats


def show_statistics(statistics):
  """Print each statistic array in ascending order, without mutating it.

  Args:
    statistics: [mean_a, std_a, mean_v, std_v, mean_l, std_l] as returned
        by get_mean_and_std_3_modal.
  """
  names = ['mean_a', 'std_a', 'mean_v', 'std_v', 'mean_l', 'std_l']
  for name, item in zip(names, statistics):
    print(f'----{name}')
    # np.sort returns a sorted copy. The previous in-place `item.sort()`
    # reordered the caller's mean/std arrays, misaligning them with the
    # feature dimensions they are later used to normalize.
    print(np.sort(item))


def norm_3_modal(data, statistics):
  """Standardize the three modalities of a split dict.

  Args:
    data: split dict with 'audio', 'vision', 'text' arrays (rebound in place).
    statistics: [mean_a, std_a, mean_v, std_v, mean_l, std_l].

  Returns:
    The same dict with each modality replaced by (x - mean) / std.
  """
  for key, idx in zip(('audio', 'vision', 'text'), (0, 2, 4)):
    data[key] = (data[key] - statistics[idx]) / statistics[idx + 1]
  return data


def concat_data(data1, data2):
  """Merge two split dicts by concatenating along the sample axis (axis 0).

  Only the four modality/label keys are carried over; any other entries
  (e.g. 'id') are dropped, matching the original behavior.
  """
  return {
      key: np.concatenate((data1[key], data2[key]), axis=0)
      for key in ('audio', 'text', 'vision', 'labels')
  }


class AVL_Dataset(Dataset):
  """Torch dataset over one split; yields (text, audio, vision, labels).

  Items are returned as raw numpy slices — conversion to tensors is left
  to the DataLoader / consumer.
  """

  def __init__(self, data, dataset):
    super().__init__()
    # For these datasets the first audio feature is rescaled. NOTE(review):
    # this divides the caller's array in place, so the same preprocessing
    # coefficient applies to train/valid/test consistently — but constructing
    # two datasets from the same dict would scale twice.
    rescaled = {
        'iemocap', 'mosei_senti', 'mosei_full_emo', 'mosei_full_senti',
        'mosi_new'
    }
    if dataset in rescaled:
      data['audio'][:, :, 0] /= 500.0

    # All attributes below remain numpy ndarrays.
    self.vision = data['vision']
    self.text = data['text']
    self.audio = data['audio']
    self.labels = data['labels']
    # Still a numpy array when present, otherwise None.
    self.meta = data.get('id')

    self.n_modalities = 3  # vision / text / audio

  def get_n_modalities(self):
    """Number of modalities (always 3)."""
    return self.n_modalities

  def get_seq_len(self):
    """Sequence lengths as (text, audio, vision)."""
    return self.text.shape[1], self.audio.shape[1], self.vision.shape[1]

  def get_dim(self):
    """Feature dimensions as (text, audio, vision)."""
    return self.text.shape[2], self.audio.shape[2], self.vision.shape[2]

  def get_lbl_info(self):
    """Return (number_of_labels, label_dim); assumes a 3-D label array."""
    return self.labels.shape[1], self.labels.shape[2]

  def __len__(self):
    return len(self.labels)

  def __getitem__(self, index):
    sample = (self.text[index], self.audio[index], self.vision[index],
              self.labels[index])
    return sample


def get_data_loaders(data_folder,
                     dataset,
                     batch_size,
                     aligned=True,
                     use_cuda=True,
                     num_workers=0):
  """Build train/valid/test DataLoaders for one dataset.

  Args:
    data_folder: directory containing the dataset pickles.
    dataset: dataset identifier (see get_data).
    batch_size: loader batch size (train loader drops the last partial batch).
    aligned: load the aligned or non-aligned pickle.
    use_cuda: pin loader memory for faster host-to-device copies.
    num_workers: number of DataLoader worker processes.

  Returns:
    (train_loader, valid_loader, test_loader); only the train loader
    shuffles and drops the last incomplete batch.
  """
  train_data, valid_data, test_data = get_data(
      data_folder, dataset, aligned=aligned)

  train_dataset = AVL_Dataset(train_data, dataset)
  valid_dataset = AVL_Dataset(valid_data, dataset)
  test_dataset = AVL_Dataset(test_data, dataset)

  # Options shared by all three loaders. prefetch_factor is only valid
  # (and only meaningful) with worker processes: the previous unconditional
  # `batch_size * 2 // num_workers + 1` raised ZeroDivisionError for the
  # default num_workers=0.
  common_kwargs = dict(
      batch_size=batch_size,
      pin_memory=use_cuda,
      num_workers=num_workers)
  if num_workers > 0:
    common_kwargs['prefetch_factor'] = batch_size * 2 // num_workers + 1

  train_loader = DataLoader(
      train_dataset, shuffle=True, drop_last=True, **common_kwargs)
  valid_loader = DataLoader(
      valid_dataset, shuffle=False, drop_last=False, **common_kwargs)
  test_loader = DataLoader(
      test_dataset, shuffle=False, drop_last=False, **common_kwargs)
  return train_loader, valid_loader, test_loader


def get_datasets(data_folder, dataset, aligned=True):
  """Load the three splits and wrap each in an AVL_Dataset.

  Returns:
    (train_dataset, valid_dataset, test_dataset).
  """
  splits = get_data(data_folder, dataset, aligned=aligned)
  return tuple(AVL_Dataset(split, dataset) for split in splits)


if __name__ == "__main__":
  data_folder = '/home/asteria/Multimodal-Transformer/data'
  dataset = 'mosei_full_senti'
  batch_size = 32
  aligned = True
  use_cuda = True
  train_dataset, valid_dataset, test_dataset = get_datasets(
      data_folder, dataset, aligned=aligned)
  train_loader, valid_loader, test_loader = get_data_loaders(
      data_folder, dataset, batch_size, aligned=aligned, use_cuda=use_cuda)
  for x in test_loader:
    print(x)